/**
  * Copyright 2018 Yahoo Inc.
  * Licensed under the terms of the Apache 2.0 license.
  * Please see LICENSE file in the project root for terms.
  */
package com.yahoo.tensorflowonspark

import java.nio._

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.tensorflow._

import scala.collection.JavaConversions._
import scala.collection.mutable.ListBuffer

/**
  * Singleton object which will cache the saved_model, graph, and session per executor JVM.
  */
object TFModel {
  // Export directory the cached bundle was loaded from; null until a model has been loaded.
  // TFModel.transform compares this against the requested model to decide whether to (re)load.
  var modelDir: String = null

  // The cached SavedModelBundle; null until a model has been loaded.
  var model: SavedModelBundle = null

  // Graph obtained from the cached bundle.
  var graph: Graph = null

  // Session obtained from the cached bundle; used to run every batch.
  var sess: Session = null
}

/**
  * Represents a trained TensorFlow model, allowing callers to transform a Dataset with predictions generated by the model.
  * This currently only supports models that fit in the memory of a single executor, and it executes the model as a
  * single-node TensorFlow application on each of the executors.  The model must be "exported" via TensorFlow's SavedModelBuilder.
  *
  * @param uid unique identifier for this Model class.
  */
class TFModel(override val uid: String) extends Model[TFModel] with TFParams {

  /** Creates a model with a randomly generated uid prefixed with "tfmodel". */
  def this() = this(Identifiable.randomUID("tfmodel"))

  /** Copies this model (including `extra` params) via Spark ML's `defaultCopy`. */
  override def copy(extra: ParamMap): TFModel = defaultCopy(extra)

  /**
    * Helper method to transform a list of Rows into a map of (column_name -> Tensor).
    *
    * The returned Tensors hold native memory and must be closed by the caller.  If conversion of any
    * column fails, the Tensors already created for earlier columns are closed before the exception
    * is rethrown.
    *
    * @param batch List of N Rows containing M Columns, where each column represents data for one Tensor.
    * @param schema schema of the incoming Rows.
    * @return a Map of M entries of (column_name -> Tensor), where each Tensor has N items in the 0-th dimension.
    * @throws Exception if a column (or array element) type is not supported.
    */
  def batch2tensors(batch: Seq[Row], schema: StructType): Map[String, Tensor[_]] = {
    val created = ListBuffer.empty[(String, Tensor[_])]
    try {
      for ((field, i) <- schema.fields.zipWithIndex) {
        val tensor: Tensor[_] = columnToTensor(batch, i, field.dataType)
        created += ((field.name, tensor))
      }
      created.toMap
    } catch {
      case scala.util.control.NonFatal(e) =>
        // don't leak the native memory of tensors built before the failure
        created.foreach { case (_, tensor) => tensor.close() }
        throw e
    }
  }

  /**
    * Builds a single Tensor from column `i` of every Row in the batch.
    *
    * Strings are encoded as UTF-8 (rather than the platform default charset) so that the bytes fed
    * to TensorFlow do not depend on executor JVM settings.
    *
    * @param batch Rows to read from.
    * @param i index of the column within each Row.
    * @param colType Spark SQL type of the column.
    * @return a Tensor with one entry per Row in the 0-th dimension.
    */
  private def columnToTensor(batch: Seq[Row], i: Int, colType: sql.types.DataType): Tensor[_] = {
    val utf8 = java.nio.charset.StandardCharsets.UTF_8
    colType match {
      // scalar types
      case sql.types.BinaryType =>
        val values: Array[Array[Byte]] = batch.map(row => row.getAs[Array[Byte]](i)).toArray
        Tensors.create(values)
      case sql.types.BooleanType =>
        val values: Array[Boolean] = batch.map(row => row.getBoolean(i)).toArray
        Tensors.create(values)
      case sql.types.DoubleType =>
        val values: Array[Double] = batch.map(row => row.getDouble(i)).toArray
        Tensors.create(values)
      case sql.types.FloatType =>
        val values: Array[Float] = batch.map(row => row.getFloat(i)).toArray
        Tensors.create(values)
      case sql.types.IntegerType =>
        val values: Array[Int] = batch.map(row => row.getInt(i)).toArray
        Tensors.create(values)
      case sql.types.LongType =>
        val values: Array[Long] = batch.map(row => row.getLong(i)).toArray
        Tensors.create(values)
      case sql.types.StringType =>
        val values: Array[Array[Byte]] = batch.map(row => row.getAs[String](i).getBytes(utf8)).toArray
        Tensors.create(values)

      // array types
      case arrayType: sql.types.ArrayType =>
        arrayType.elementType match {
          case sql.types.BinaryType =>
            val values: Array[Array[Array[Byte]]] = batch.map(row => row.getSeq[Array[Byte]](i).toArray).toArray
            Tensors.create(values)
          case sql.types.BooleanType =>
            val values: Array[Array[Boolean]] = batch.map(row => row.getSeq[Boolean](i).toArray).toArray
            Tensors.create(values)
          case sql.types.DoubleType =>
            val values: Array[Array[Double]] = batch.map(row => row.getSeq[Double](i).toArray).toArray
            Tensors.create(values)
          case sql.types.FloatType =>
            val values: Array[Array[Float]] = batch.map(row => row.getSeq[Float](i).toArray).toArray
            Tensors.create(values)
          case sql.types.IntegerType =>
            val values: Array[Array[Int]] = batch.map(row => row.getSeq[Int](i).toArray).toArray
            Tensors.create(values)
          case sql.types.LongType =>
            val values: Array[Array[Long]] = batch.map(row => row.getSeq[Long](i).toArray).toArray
            Tensors.create(values)
          case sql.types.StringType =>
            val values: Array[Array[Array[Byte]]] =
              batch.map(row => row.getSeq[String](i).map(s => s.getBytes(utf8)).toArray).toArray
            Tensors.create(values)
          case unsupportedType =>
            throw new Exception(s"Unsupported base type in array: $unsupportedType")
        }

      case unsupportedType =>
        throw new Exception(s"Unsupported column type: $unsupportedType")
    }
  }

  /**
    * Helper method to transform a list of Tensors into a list of Rows.
    *
    * Tensors of rank 1 yield one scalar value per Row; tensors of higher rank yield one flattened
    * array per Row.  The values produced match the column types declared by `transformSchema`:
    * STRING tensors become `String`s (decoded as UTF-8) and UINT8 tensors become one-byte
    * `Array[Byte]` values.
    *
    * This method does not close the supplied Tensors.
    *
    * @param tensors Sequence of M TensorFlow Tensors, each with N items in the 0-th dimension.
    * @return List of N Rows with M Columns.
    * @throws IllegalArgumentException if any tensor has rank 0.
    * @throws Exception if a tensor has an unsupported data type.
    */
  def tensors2batch(tensors: Seq[Tensor[_]]): List[Row] = {
    require(tensors.forall(t => t.shape.nonEmpty),
      "Output tensors must have at least one dimension (the batch dimension)")
    // all output tensors must have same cardinality in the 0-dimension
    assert(tensors.map(t => t.shape()(0)).distinct.size == 1)
    val numRows = tensors.head.shape()(0).toInt

    if (numRows == 0) {
      Nil
    } else {
      // one column of N per-row values for each tensor
      val columns = tensors.map(t => tensorColumn(t, numRows))
      List.tabulate(numRows)(i => Row(columns.map(column => column(i)): _*))
    }
  }

  /**
    * Splits one Tensor into `numRows` per-row values, reading its contents in row-major order.
    *
    * @param t the Tensor to read; its 0-th dimension must equal `numRows`.
    * @param numRows number of rows in the batch.
    * @return `numRows` values: scalars for rank-1 tensors, flattened arrays otherwise.
    */
  private def tensorColumn(t: Tensor[_], numRows: Int): Vector[Any] = {
    val shape = t.shape
    // 0-th dimension is the number of rows, so rank 1 means one scalar per row
    val scalarPerRow = shape.length == 1
    // number of elements belonging to each row (1 for rank-1 tensors)
    val width = shape.drop(1).map(_.toInt).product

    // evaluates exactly one of the two by-name readers once per row, in row order
    def perRow(scalar: => Any, array: => Any): Vector[Any] =
      Vector.fill[Any](numRows)(if (scalarPerRow) scalar else array)

    t.dataType match {
      case DataType.BOOL =>
        val buf = ByteBuffer.allocate(t.numElements)
        t.writeTo(buf)
        buf.flip()
        perRow(
          buf.get() == 1.toByte,
          { val bytes = new Array[Byte](width); buf.get(bytes); bytes.map(b => b == 1.toByte) })
      case DataType.DOUBLE =>
        val buf = DoubleBuffer.allocate(t.numElements)
        t.writeTo(buf)
        buf.flip()
        perRow(buf.get(), { val values = new Array[Double](width); buf.get(values); values })
      case DataType.FLOAT =>
        val buf = FloatBuffer.allocate(t.numElements)
        t.writeTo(buf)
        buf.flip()
        perRow(buf.get(), { val values = new Array[Float](width); buf.get(values); values })
      case DataType.INT32 =>
        val buf = IntBuffer.allocate(t.numElements)
        t.writeTo(buf)
        buf.flip()
        perRow(buf.get(), { val values = new Array[Int](width); buf.get(values); values })
      case DataType.INT64 =>
        val buf = LongBuffer.allocate(t.numElements)
        t.writeTo(buf)
        buf.flip()
        perRow(buf.get(), { val values = new Array[Long](width); buf.get(values); values })
      case DataType.STRING =>
        // writeTo() would only expose TensorFlow's raw string encoding, so decode via copyTo()
        val strings = readStrings(t).iterator
        perRow(strings.next(), Array.fill(width)(strings.next()))
      case DataType.UINT8 =>
        val buf = ByteBuffer.allocate(t.numBytes())
        t.writeTo(buf)
        // flip is required before reading; without it nothing remains in the buffer
        buf.flip()
        // each UINT8 element is emitted as a one-byte binary to match the declared
        // BinaryType / ArrayType(BinaryType) column types
        perRow(Array(buf.get()), Array.fill(width)(Array(buf.get())))
      case unsupportedType =>
        throw new Exception(s"Unsupported output Tensor type: $unsupportedType")
    }
  }

  /**
    * Reads every element of a STRING Tensor, flattened in row-major order and decoded as UTF-8.
    *
    * @param t a Tensor of data type STRING and rank >= 1.
    * @return all string elements of the tensor; empty if the tensor has no elements.
    */
  private def readStrings(t: Tensor[_]): Vector[String] = {
    if (t.numElements == 0) {
      Vector.empty
    } else {
      val utf8 = java.nio.charset.StandardCharsets.UTF_8
      val dims = t.shape.map(_.toInt)
      // an N-dimensional string tensor is copied into an N-dimensional array of byte[]
      val nested = java.lang.reflect.Array.newInstance(classOf[Array[Byte]], dims: _*)
      t.copyTo(nested)

      def flatten(value: AnyRef): Vector[String] = value match {
        case bytes: Array[Byte] => Vector(new String(bytes, utf8))
        case items: Array[AnyRef] => items.toVector.flatMap(item => flatten(item))
        case other => throw new Exception(s"Unexpected value in string tensor: $other")
      }

      flatten(nested)
    }
  }

  /**
    * Transforms a Dataset of input data into an output DataFrame of predictions.
    * Note: input columns are dropped during transform.
    *
    * The saved model is loaded at most once per executor JVM (and reloaded if a different model
    * directory is requested).  Input rows are processed in batches of `getBatchSize`; the Tensors
    * created for each batch are closed as soon as the batch has been converted back to Rows.
    *
    * @param dataset input data; must contain every column named in the input mapping.
    * @return a DataFrame with one column per entry of the output mapping.
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val spark = dataset.sparkSession

    val inputColumns = this.getInputMapping.keys.toSeq
    val outputTensorNames = this.getOutputMapping.keys.toSeq

    val inputDF = dataset.select(inputColumns.head, inputColumns.tail: _*)
    val inputSchema = inputDF.schema
    val outputSchema = transformSchema(inputSchema)

    val outputRDD = inputDF.rdd.mapPartitions { iter: Iterator[Row] =>
      // Several tasks may run concurrently in one executor JVM, so guard the check-then-load.
      // The previously cached bundle (if any) is deliberately not closed here, since another
      // running task may still be using its session.
      val session = TFModel.synchronized {
        if (TFModel.model == null || TFModel.modelDir != this.getModel) {
          // load model into a per-executor singleton reference, if needed.
          TFModel.modelDir = this.getModel
          TFModel.model = SavedModelBundle.load(this.getModel, this.getTag)
          TFModel.graph = TFModel.model.graph
          TFModel.sess = TFModel.model.session
        }
        TFModel.sess
      }

      iter.grouped(this.getBatchSize).flatMap { batch =>
        // get input batch of Rows and convert to list of input Tensors
        val inputTensors = batch2tensors(batch, inputSchema)
        try {
          var runner = session.runner()

          // feed input tensors
          for ((name, tensor) <- inputTensors) {
            runner = runner.feed(this.getInputMapping(name), tensor)
          }
          // fetch output tensors
          for (name <- outputTensorNames) {
            runner = runner.fetch(name)
          }

          // run the graph
          val outputTensors = runner.run()
          try {
            assert(outputTensors.map(_.shape).map(s => if (s.isEmpty) 0L else s.apply(0)).distinct.size == 1,
              "Cardinality of output tensors must match")

            // convert the list of output Tensors to a batch of output Rows
            tensors2batch(outputTensors)
          } finally {
            // release native memory held by the outputs
            outputTensors.foreach(t => t.close())
          }
        } finally {
          // release native memory held by the inputs
          inputTensors.values.foreach(t => t.close())
        }
      }
    }

    spark.createDataFrame(outputRDD, outputSchema)
  }

  /**
    * Derives the output schema from the saved model's graph: one field per entry of the output
    * mapping, typed according to the corresponding output tensor.  The input schema is ignored.
    *
    * The saved model is loaded for inspection and closed again before returning.
    *
    * @param schema input schema (unused, since all input columns are dropped).
    * @return schema of the predictions DataFrame.
    * @throws Exception if a tensor named in the output mapping is not found in the graph.
    */
  override def transformSchema(schema: StructType): StructType = {
    val model = SavedModelBundle.load(this.getModel, this.getTag)
    try {
      val g = model.graph

      val fields = this.getOutputMapping.map { case (tensorName, columnName) =>
        val op = g.operation(tensorName)
        // if a requested tensorName is not found, dump all operations in the graph to aid user and throw exception.
        if (op == null) {
          g.operations().foreach(println)
          throw new Exception(s"op $tensorName is null")
        }
        val output: Output[_] = op.output(0)
        makeStructField(columnName, output)
      }

      // ignore input schema, since we're dropping all inputs for output predictions.
      // materialize the fields while the graph is still open.
      StructType(fields.toArray)
    } finally {
      model.close()
    }
  }

  /**
    * Returns the Spark DataFrame StructField for a given output Tensor.
    *
    * Tensors with exactly one dimension map to a scalar column type; all others map to an
    * ArrayType of that scalar type.
    *
    * @param columnName name of the Spark DataFrame column.
    * @param output OutputTensor
    * @return a non-nullable StructField named `columnName`.
    * @throws Exception if the tensor's data type is not supported.
    */
  private def makeStructField(columnName: String, output: Output[_]): StructField = {
    // TODO: support other DataTypes
    val elementType: sql.types.DataType = output.dataType match {
      case DataType.BOOL => sql.types.BooleanType
      case DataType.DOUBLE => sql.types.DoubleType
      case DataType.FLOAT => sql.types.FloatType
      case DataType.INT32 => sql.types.IntegerType
      case DataType.INT64 => sql.types.LongType
      case DataType.UINT8 => sql.types.BinaryType
      case DataType.STRING => sql.types.StringType
      case unsupportedType => throw new Exception(s"Unsupported type: $unsupportedType")
    }

    // TODO: support tensors with more than 2 dimensions
    val columnType: sql.types.DataType =
      if (output.shape.numDimensions == 1) elementType else sql.types.ArrayType(elementType)

    StructField(columnName, columnType, nullable = false)
  }

}
